# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

##
# @file   NonLinearPlace.py
# @author Yibo Lin
# @date   Jul 2018
# @brief  Nonlinear placement engine to be called with parameters and placement database
#

import os
import sys
import time
import pickle
import numpy as np
import logging
import torch
from torch.optim.lr_scheduler import ExponentialLR
import gzip
import copy
import math
import matplotlib.pyplot as plt

if sys.version_info[0] < 3:
    import cPickle as pickle
else:
    import _pickle as pickle
import dreamplace.BasicPlace as BasicPlace
import dreamplace.PlaceObj as PlaceObj
import dreamplace.NesterovAcceleratedGradientOptimizer as NesterovAcceleratedGradientOptimizer
import dreamplace.EvalMetrics as EvalMetrics
import pdb
import dreamplace.ops.fence_region.fence_region as fence_region
import dreamplace.ops.place_io.place_io_cpp as place_io_cpp


class NonLinearPlace(BasicPlace.BasicPlace):
    """
    @brief Nonlinear placement engine.
    It takes parameters and placement database and runs placement flow.
    """

    def __init__(self, params, placedb):
        """
        @brief Construct the nonlinear placement engine.
        @param params user parameters controlling the placement flow
        @param placedb placement database with netlist and geometry
        """
        # All data/op collection setup is delegated to BasicPlace.
        super().__init__(params, placedb)

    def __call__(self, params, placedb):
        """
        @brief Top API to solve placement.
        @param params parameters
        @param placedb placement database
        """
        iteration = 0
        all_metrics = []
        lrs = []
        scheduler = None
        plot_frequency = 100

        # global placement
        if params.global_place_flag:
            # global placement may run in multiple stages according to user specification
            for global_place_params in params.global_place_stages:

                # we formulate each stage as a 3-nested optimization problem
                # f_gamma(g_density(h(x) ; density weight) ; gamma)
                # Lgamma      Llambda        Lsub
                # When optimizing an inner problem, the outer parameters are fixed.
                # This is a generalization to the eplace/RePlAce approach

                # As global placement may easily diverge, we record the position of best overflow
                best_metric = [None]
                best_pos = [None]

                if params.gpu:
                    torch.cuda.synchronize()
                tt = time.time()
                # construct model and optimizer
                density_weight = 0.0
                # construct placement model
                model = PlaceObj.PlaceObj(
                    density_weight,
                    params,
                    placedb,
                    self.data_collections,
                    self.op_collections,
                    global_place_params,
                ).to(self.data_collections.pos[0].device)
                optimizer_name = global_place_params["optimizer"]

                # determine optimizer
                if optimizer_name.lower() == "adam":
                    optimizer = torch.optim.Adam(self.parameters(), lr=0)
                elif optimizer_name.lower() == "sgd":
                    optimizer = torch.optim.SGD(self.parameters(), lr=0)
                elif optimizer_name.lower() == "sgd_momentum":
                    optimizer = torch.optim.SGD(
                        self.parameters(), lr=0, momentum=0.9, nesterov=False
                    )
                elif optimizer_name.lower() == "sgd_nesterov":
                    optimizer = torch.optim.SGD(
                        self.parameters(), lr=0, momentum=0.9, nesterov=True
                    )
                elif optimizer_name.lower() == "nesterov":
                    optimizer = NesterovAcceleratedGradientOptimizer.NesterovAcceleratedGradientOptimizer(
                        self.parameters(),
                        lr=0,
                        obj_and_grad_fn=model.obj_and_grad_fn,
                        constraint_fn=self.op_collections.move_boundary_op,
                    )
                else:
                    assert 0, "unknown optimizer %s" % (optimizer_name)

                logging.info("use %s optimizer" % (optimizer_name))
                model.train()
                # defining evaluation ops
                eval_ops = {
                    # "wirelength" : self.op_collections.wirelength_op,
                    # "density" : self.op_collections.density_op,
                    # "objective" : model.obj_fn,
                    # "weight_hpwl": self.op_collections.weight_hpwl_op,
                    "hpwl": self.op_collections.hpwl_op,
                    "overflow": self.op_collections.density_overflow_op,
                }
                if params.routability_opt_flag:
                    eval_ops.update(
                        {
                            "route_utilization": self.op_collections.route_utilization_map_op,
                            "pin_utilization": self.op_collections.pin_utilization_map_op,
                        }
                    )
                if params.macro_overlap_flag:
                    eval_ops.update(
                        {"macro_overlap": self.op_collections.macro_overlap_op}
                    )
                if len(placedb.regions) > 0:
                    eval_ops.update(
                        {
                            "density": self.op_collections.fence_region_density_merged_op,
                            "overflow": self.op_collections.fence_region_density_overflow_merged_op,
                            "goverflow": self.op_collections.density_overflow_op,
                        }
                    )

                # a function to initialize learning rate
                def initialize_learning_rate(pos):
                    """Estimate an initial step size from *pos* and install it on
                    every optimizer parameter group, recording it in ``lrs``."""
                    lr = model.estimate_initial_learning_rate(
                        pos, global_place_params["learning_rate"]
                    )
                    # Push the estimated rate into each parameter group.
                    for group in optimizer.param_groups:
                        group["lr"] = lr.data
                    # Keep a history of the learning rates actually applied.
                    lrs.append(optimizer.param_groups[0]["lr"])

                if iteration == 0:
                    if params.gp_noise_ratio > 0.0 and params.random_center_init_flag:
                        logging.info("add %g%% noise" % (params.gp_noise_ratio * 100))
                        model.op_collections.noise_op(
                            model.data_collections.pos[0], params.gp_noise_ratio
                        )
                    initialize_learning_rate(model.data_collections.pos[0])
                    if (
                        optimizer_name.lower()
                        in [
                            "sgd",
                            "adam",
                            "sgd_momentum",
                            "sgd_nesterov",
                        ]
                        and "learning_rate_decay" in global_place_params
                    ):
                        scheduler = ExponentialLR(
                            optimizer,
                            gamma=global_place_params["learning_rate_decay"],
                        )
                # the state must be saved after setting learning rate
                initial_state = copy.deepcopy(optimizer.state_dict())

                if params.gpu:
                    torch.cuda.synchronize()
                logging.info(
                    "%s initialization takes %g seconds"
                    % (optimizer_name, (time.time() - tt))
                )

                # as nesterov requires line search, we cannot follow the convention of other solvers
                if optimizer_name.lower() in {
                    "sgd",
                    "adam",
                    "sgd_momentum",
                    "sgd_nesterov",
                }:
                    model.obj_and_grad_fn(model.data_collections.pos[0])
                elif optimizer_name.lower() != "nesterov":
                    assert 0, "unsupported optimizer %s" % (optimizer_name)

                # stopping criteria
                def Lgamma_stop_criterion(Lgamma_step, metrics, stop_mask=None):
                    """Decide whether the outermost (gamma) loop should stop.

                    @param Lgamma_step current outer iteration index
                    @param metrics nested metrics list; metrics[-1][-1][-1] is the
                           most recent EvalMetrics entry
                    @param stop_mask unused placeholder kept for interface
                           compatibility with per-region stopping (see the
                           commented-out code below)
                    @return True when global placement should terminate
                    """
                    with torch.no_grad():
                        if len(metrics) > 1:
                            cur_metric = metrics[-1][-1][-1]
                            prev_metric = metrics[-2][-1][-1]
                            ### update stop mask for each fence region
                            # if(stop_mask is not None):
                            #     stop_mask.copy_(cur_metric.overflow < params.stop_overflow)

                            # NOTE: the log message below hardcodes 0.1/1.0 but the
                            # actual thresholds are params.stop_overflow and
                            # params.target_density.
                            if Lgamma_step > 100 and (
                                ### for fence region, the outer cell overflow decides the stopping of GP
                                (
                                    cur_metric.overflow[-1] < params.stop_overflow
                                    and cur_metric.hpwl > prev_metric.hpwl
                                )
                                or cur_metric.max_density[-1] < params.target_density
                            ):
                                logging.info(
                                    "Lgamma stopping criteria: %d > 100 and (( %g < 0.1 and %g > %g ) or %g < 1.0)"
                                    % (
                                        Lgamma_step,
                                        cur_metric.overflow[-1],
                                        cur_metric.hpwl,
                                        prev_metric.hpwl,
                                        cur_metric.max_density[-1],
                                    )
                                )
                                return True
                            # With fence regions, stop once every region has been
                            # frozen (update_mask all zeros).
                            if (
                                len(placedb.regions) > 0
                                and model.update_mask.sum() == 0
                            ):
                                logging.info(
                                    "All regions stop updating, finish global placement"
                                )
                                return True
                        # a heuristic to detect divergence and stop early
                        """
                        # remove another heuristics
                        if len(metrics) > 50:
                            cur_metric = metrics[-1][-1][-1]
                            prev_metric = metrics[-50][-1][-1]
                            # record HPWL and overflow increase, and check divergence
                            if (
                                cur_metric.overflow[-1] > prev_metric.overflow[-1]
                                and cur_metric.hpwl > best_metric[0].hpwl * 2
                            ):
                                return True
                        """
                        # stop when there is nan or inf numbers, because it will never get better
                        if torch.isinf(metrics[-1][-1][-1].objective) or torch.isnan(
                            metrics[-1][-1][-1].objective
                        ):
                            return True

                        return False

                def Llambda_stop_criterion(
                    Lgamma_step, Llambda_density_weight_step, metrics
                ):
                    """Stop the middle (density-weight) loop once the outer-cell
                    overflow is below the stop threshold while wirelength starts
                    growing, or the max density has dropped below 1."""
                    with torch.no_grad():
                        if len(metrics) <= 1:
                            return False
                        cur = metrics[-1][-1]
                        prev = metrics[-2][-1]
                        ### for fence regions, the outer cell overflow and max_density decides whether to stop
                        converged = (
                            cur.overflow[-1] < params.stop_overflow
                            and cur.hpwl > prev.hpwl
                        )
                        if converged or cur.max_density[-1] < 1.0:
                            logging.info(
                                "Llambda stopping criteria: %d and (( %g < 0.1 and %g > %g ) or %g < 1.0)"
                                % (
                                    Llambda_density_weight_step,
                                    cur.overflow[-1],
                                    cur.hpwl,
                                    prev.hpwl,
                                    cur.max_density[-1],
                                )
                            )
                            return True
                        return False

                # use a moving average window for stopping criteria, for an example window of 3
                # 0, 1, 2, 3, 4, 5, 6
                #    window2
                #             window1
                moving_avg_window = max(min(model.Lsub_iteration // 2, 3), 1)

                def Lsub_stop_criterion(
                    Lgamma_step, Llambda_density_weight_step, Lsub_step, metrics
                ):
                    """Stop the innermost loop when the moving-average objective
                    over the last window no longer improves by more than 0.1%
                    versus the preceding window."""
                    with torch.no_grad():
                        if len(metrics) < moving_avg_window * 2:
                            return False
                        # Two adjacent windows of the most recent metrics;
                        # reversed() keeps the newest-first summation order.
                        recent = metrics[-moving_avg_window:]
                        earlier = metrics[-2 * moving_avg_window : -moving_avg_window]
                        cur_avg_obj = (
                            sum(m.objective for m in reversed(recent))
                            / moving_avg_window
                        )
                        prev_avg_obj = (
                            sum(m.objective for m in reversed(earlier))
                            / moving_avg_window
                        )
                        threshold = 0.999
                        if cur_avg_obj >= prev_avg_obj * threshold:
                            logging.info(
                                "Lsub stopping criteria: %d and %g > %g * %g"
                                % (Lsub_step, cur_avg_obj, prev_avg_obj, threshold)
                            )
                            return True
                    return False

                def one_descent_step(
                    Lgamma_step,
                    Llambda_density_weight_step,
                    Lsub_step,
                    iteration,
                    metrics,
                    stop_mask=None,
                ):
                    """Run one optimizer step and record an EvalMetrics entry.

                    @param Lgamma_step, Llambda_density_weight_step, Lsub_step
                           indices of the three nested optimization loops
                    @param iteration global flat iteration counter (used for
                           logging and plot naming)
                    @param metrics list to which the new EvalMetrics is appended
                    @param stop_mask unused placeholder kept for interface
                           compatibility
                    """
                    t0 = time.time()

                    # metric for this iteration
                    cur_metric = EvalMetrics.EvalMetrics(
                        iteration, (Lgamma_step, Llambda_density_weight_step, Lsub_step)
                    )
                    cur_metric.gamma = model.gamma.data
                    cur_metric.density_weight = model.density_weight.data
                    if params.macro_overlap_flag:
                        cur_metric.macro_overlap_weight = (
                            model.macro_overlap_weight.data
                        )
                    metrics.append(cur_metric)
                    pos = model.data_collections.pos[0]

                    # move any out-of-bound cell back to placement region
                    self.op_collections.move_boundary_op(pos)

                    # handle multiple density weights for multi-electric field
                    # Lazy init: a zero mean means the weight was never set.
                    if torch.eq(model.density_weight.mean(), 0.0):
                        model.initialize_density_weight(params, placedb)
                        if model.density_weight.size(0) == 1:
                            logging.info(
                                "density_weight = %.6E" % (model.density_weight.data)
                            )
                        else:
                            logging.info(
                                "density_weight = [%s]"
                                % ", ".join(
                                    [
                                        "%.3E" % i
                                        for i in model.density_weight.cpu()
                                        .numpy()
                                        .tolist()
                                    ]
                                )
                            )

                    # initialize macro overlap weight
                    if params.macro_overlap_flag and torch.eq(
                        model.macro_overlap_weight, 0.0
                    ):
                        model.initialize_macro_overlap_weight(params, placedb)
                        logging.info(
                            "macro_overlap_weight = %.6E"
                            % (model.macro_overlap_weight.data)
                        )

                    # Clear gradients accumulated by the previous step.
                    optimizer.zero_grad()

                    # t1 = time.time()
                    # Evaluate hpwl/overflow (and optional ops) at the current pos.
                    cur_metric.evaluate(placedb, eval_ops, pos, model.data_collections)
                    model.overflow = cur_metric.overflow.data.clone()
                    # logging.debug("evaluation %.3f ms" % ((time.time()-t1)*1000))
                    # t2 = time.time()

                    # as nesterov requires line search, we cannot follow the convention of other solvers
                    if optimizer_name.lower() in [
                        "sgd",
                        "adam",
                        "sgd_momentum",
                        "sgd_nesterov",
                    ]:
                        obj, grad = model.obj_and_grad_fn(pos)
                        cur_metric.objective = obj.data.clone()
                    elif optimizer_name.lower() != "nesterov":
                        assert 0, "unsupported optimizer %s" % (optimizer_name)

                    # plot placement
                    if params.plot_flag and iteration % plot_frequency == 0:
                        cur_pos = self.pos[0].data.clone().cpu().numpy()
                        self.plot(params, placedb, iteration, cur_pos)

                    #### stop updating fence regions that are marked stop, exclude the outer cell !
                    t3 = time.time()
                    if model.update_mask is not None:
                        # Snapshot positions so frozen regions can be restored
                        # after the optimizer moves everything.
                        pos_bk = pos.data.clone()
                        optimizer.step()

                        for region_id, fence_region_update_flag in enumerate(
                            model.update_mask
                        ):
                            if fence_region_update_flag == 0:
                                ### don't update cell location in that region
                                mask = self.op_collections.fence_region_density_ops[
                                    region_id
                                ].pos_mask
                                pos.data.masked_scatter_(mask, pos_bk[mask])
                    else:
                        optimizer.step()

                    logging.info("optimizer step %.3f ms" % ((time.time() - t3) * 1000))

                    # nesterov has already computed the objective of the next step
                    if optimizer_name.lower() == "nesterov":
                        cur_metric.objective = optimizer.param_groups[0]["obj_k_1"][
                            0
                        ].data.clone()

                    # actually reports the metric before step
                    logging.info(cur_metric)
                    # record the best outer cell overflow
                    if (
                        best_metric[0] is None
                        or best_metric[0].overflow[-1] > cur_metric.overflow[-1]
                    ):
                        best_metric[0] = cur_metric
                        if best_pos[0] is None:
                            best_pos[0] = self.pos[0].data.clone()
                        else:
                            best_pos[0].data.copy_(self.pos[0].data)

                    logging.info("full step %.3f ms" % ((time.time() - t0) * 1000))

                def check_plateau(x, window=10, threshold=0.001):
                    """Return True when the last *window* samples of *x* span a
                    relative range below *threshold* (i.e. the curve is flat)."""
                    if len(x) < window:
                        return False
                    tail = x[-window:]
                    spread = np.max(tail) - np.min(tail)
                    return spread / np.mean(tail) < threshold

                def check_divergence(x, window=50, threshold=0.05):
                    """Heuristically detect divergence from recent (hpwl, overflow)
                    samples relative to the best metric recorded so far.

                    @param x list of [hpwl, overflow] pairs, one per iteration
                    @param window number of trailing samples to inspect
                    @param threshold relative degradation tolerance
                    @return True if the optimization trajectory looks divergent
                    """
                    # Not enough history, or no best snapshot to compare against.
                    if len(x) < window or best_metric[0] is None:
                        return False
                    x = np.array(x[-window:])
                    overflow_mean = np.mean(x[:, 1])
                    # Fraction of consecutive steps where overflow increased (0..1).
                    overflow_diff = np.maximum(0, np.sign(x[1:, 1] - x[:-1, 1])).astype(
                        np.float32
                    )
                    overflow_diff = np.sum(overflow_diff) / overflow_diff.shape[0]
                    overflow_range = np.max(x[:, 1]) - np.min(x[:, 1])
                    wl_mean = np.mean(x[:, 0])
                    # Relative degradation versus the best snapshot so far.
                    # NOTE(review): assumes best_metric[0].overflow is a scalar
                    # tensor (uses .item()) — confirm for multi-region runs.
                    best_hpwl = best_metric[0].hpwl.item()
                    best_overflow = best_metric[0].overflow.item()
                    wl_ratio = (wl_mean - best_hpwl) / best_hpwl
                    overflow_ratio = (
                        overflow_mean - max(params.stop_overflow, best_overflow)
                    ) / best_overflow
                    # Only consider divergence once wirelength has clearly degraded.
                    if wl_ratio <= threshold * 1.2:
                        return False
                    # logging.warn is a deprecated alias; use logging.warning.
                    if overflow_ratio > threshold:
                        logging.warning(
                            f"Divergence detected: overflow increases too much than best overflow ({overflow_ratio:.4f} > {threshold:.4f})"
                        )
                        return True
                    if overflow_range / overflow_mean < threshold:
                        logging.warning(
                            f"Divergence detected: overflow plateau ({overflow_range/overflow_mean:.4f} < {threshold:.4f})"
                        )
                        return True
                    if overflow_diff > 0.6:
                        logging.warning(
                            f"Divergence detected: overflow fluctuate too frequently ({overflow_diff:.2f} > 0.6)"
                        )
                        return True
                    return False

                def entropy_injection(
                    pos,
                    placedb,
                    shrink_factor=1,
                    noise_intensity=1,
                    mode="random",
                    iteration=1,
                ):
                    """Perturb cell locations to escape a plateau.

                    Optionally contracts all coordinates toward the movable-cell
                    centroid and adds Gaussian noise; fixed cells are restored to
                    their original positions afterwards. Only mode="random" is
                    implemented.
                    """
                    if mode != "random":
                        raise NotImplementedError
                    num_movable_nodes = placedb.num_movable_nodes
                    num_nodes = placedb.num_nodes
                    num_filler_nodes = placedb.num_filler_nodes
                    num_fixed_nodes = (
                        num_nodes - num_movable_nodes - num_filler_nodes
                    )
                    # Centroid of movable cells; x occupies the first half of pos,
                    # y the second half.
                    xc = pos[:num_movable_nodes].data.mean()
                    yc = pos.data[num_nodes : num_nodes + num_movable_nodes].mean()

                    # Snapshot fixed-cell coordinates so they can be restored.
                    fixed_pos_x = pos.data[
                        num_movable_nodes : num_movable_nodes + num_fixed_nodes
                    ].clone()
                    fixed_pos_y = pos.data[
                        num_nodes
                        + num_movable_nodes : num_nodes
                        + num_movable_nodes
                        + num_fixed_nodes
                    ].clone()
                    if shrink_factor != 1:
                        # Contract every coordinate toward the movable centroid.
                        pos.data[:num_nodes] = (
                            pos.data[:num_nodes] - xc
                        ) * shrink_factor + xc
                        pos.data[num_nodes:] = (
                            pos.data[num_nodes:] - yc
                        ) * shrink_factor + yc
                    if noise_intensity > 0.01:
                        pos.data.add_(
                            noise_intensity
                            * torch.randn(num_nodes * 2, device=pos.device)
                        )

                    # Undo any movement of the fixed cells.
                    pos.data[
                        num_movable_nodes : num_movable_nodes + num_fixed_nodes
                    ] = fixed_pos_x
                    pos.data[
                        num_nodes
                        + num_movable_nodes : num_nodes
                        + num_movable_nodes
                        + num_fixed_nodes
                    ] = fixed_pos_y

                Lgamma_metrics = all_metrics

                if params.routability_opt_flag:
                    adjust_area_flag = True
                    adjust_route_area_flag = (
                        params.adjust_nctugr_area_flag or params.adjust_rudy_area_flag
                    )
                    adjust_pin_area_flag = params.adjust_pin_area_flag
                    num_area_adjust = 0

                Llambda_flat_iteration = 0

                ### preparation for self-adaptive divergence check
                overflow_list = [1]
                divergence_list = []
                min_perturb_interval = 50
                stop_placement = 0
                last_perturb_iter = -min_perturb_interval
                perturb_counter = 0

                # reduce iterations
                model.Lgamma_iteration /= (
                    model.Llambda_density_weight_iteration * model.Lsub_iteration
                )
                model.Lgamma_iteration = int(model.Lgamma_iteration)

                for Lgamma_step in range(model.Lgamma_iteration):
                    Lgamma_metrics.append([])
                    Llambda_metrics = Lgamma_metrics[-1]
                    for Llambda_density_weight_step in range(
                        model.Llambda_density_weight_iteration
                    ):
                        Llambda_metrics.append([])
                        Lsub_metrics = Llambda_metrics[-1]
                        for Lsub_step in range(model.Lsub_iteration):
                            ## divergence threshold should decrease as overflow decreases
                            ## only detect divergence when overflow is relatively low but not too low
                            """
                            # this heuristics makes placement unstable
                            if (
                                len(placedb.regions) == 0
                                and params.stop_overflow * 1.1 < overflow_list[-1] < params.stop_overflow * 4
                                and check_divergence(
                                    divergence_list, window=3, threshold=0.01 * overflow_list[-1]
                                )
                            ):
                                self.pos[0].data.copy_(best_pos[0].data)
                                stop_placement = 1

                                logging.error(
                                    "possible DIVERGENCE detected, roll back to the best position recorded"
                                )
                            """

                            one_descent_step(
                                Lgamma_step,
                                Llambda_density_weight_step,
                                Lsub_step,
                                iteration,
                                Lsub_metrics,
                            )

                            if len(placedb.regions) == 0:
                                overflow_list.append(
                                    Llambda_metrics[-1][-1].overflow.data.item()
                                )
                                divergence_list.append(
                                    [
                                        Llambda_metrics[-1][-1].hpwl.data.item(),
                                        Llambda_metrics[-1][-1].overflow.data.item(),
                                    ]
                                )

                            ## quadratic penalty and entropy injection
                            """
                            # This heuristics makes placement unstable
                            if (
                                len(placedb.regions) == 0
                                and iteration - last_perturb_iter > min_perturb_interval
                                and check_plateau(overflow_list, window=15, threshold=0.001)
                            ):
                                if overflow_list[-1] > 0.9:  # stuck at high overflow
                                    model.quad_penalty = True
                                    model.density_factor *= 2
                                    logging.info(
                                        f"Stuck at early stage. Turn on quadratic penalty with double density factor to accelerate convergence"
                                    )
                                    if overflow_list[-1] > 0.95:  # stuck at very high overflow
                                        noise_intensity = min(
                                            max(40 + (120 - 40) * (overflow_list[-1] - 0.95) * 10, 40), 90
                                        )
                                        entropy_injection(
                                            self.pos[0],
                                            placedb,
                                            shrink_factor=0.996,
                                            noise_intensity=noise_intensity,
                                            mode="random",
                                        )
                                        logging.info(
                                            f"Stuck at very early stage. Turn on entropy injection with noise intensity = {noise_intensity} to help convergence"
                                        )
                                    last_perturb_iter = iteration
                                    perturb_counter += 1
                            """
                            iteration += 1
                            # stopping criteria
                            if Lsub_stop_criterion(
                                Lgamma_step,
                                Llambda_density_weight_step,
                                Lsub_step,
                                Lsub_metrics,
                            ):
                                break
                        # flat counter of Llambda iterations across all Lgamma steps
                        Llambda_flat_iteration += 1

                        # update density weight
                        if Llambda_flat_iteration > 1:
                            # previous metric is the last entry of the previous
                            # Llambda step; on the first Llambda entry it falls
                            # back to the previous Lgamma step's last metric
                            model.op_collections.update_density_weight_op(
                                Llambda_metrics[-1][-1],
                                Llambda_metrics[-2][-1]
                                if len(Llambda_metrics) > 1
                                else Lgamma_metrics[-2][-1][-1],
                                Llambda_flat_iteration,
                            )
                        # logging.debug("update density weight %.3f ms" % ((time.time()-t2)*1000))

                        # update macro overlap weight
                        if Llambda_flat_iteration > 1 and params.macro_overlap_flag:
                            # same previous-metric selection as the density weight update
                            model.op_collections.update_macro_overlap_weight_op(
                                Llambda_metrics[-1][-1],
                                Llambda_metrics[-2][-1]
                                if len(Llambda_metrics) > 1
                                else Lgamma_metrics[-2][-1][-1],
                                Llambda_flat_iteration,
                            )

                        if Llambda_stop_criterion(
                            Lgamma_step, Llambda_density_weight_step, Llambda_metrics
                        ):
                            break

                        # for routability optimization
                        # triggered only while area-adjust rounds remain and the
                        # density overflow has dropped below the configured threshold
                        if (
                            params.routability_opt_flag
                            and num_area_adjust < params.max_num_area_adjust
                            and Llambda_metrics[-1][-1].overflow
                            < params.node_area_adjust_overflow
                        ):
                            content = (
                                "routability optimization round %d: adjust area flags = (%d, %d, %d)"
                                % (
                                    num_area_adjust,
                                    adjust_area_flag,
                                    adjust_route_area_flag,
                                    adjust_pin_area_flag,
                                )
                            )
                            pos = model.data_collections.pos[0]

                            route_utilization_map = None
                            pin_utilization_map = None
                            if adjust_route_area_flag:
                                # choose the congestion estimator: NCTUgr-based
                                # map or the built-in route utilization map op
                                if params.adjust_nctugr_area_flag:
                                    route_utilization_map = (
                                        model.op_collections.nctugr_congestion_map_op(
                                            pos
                                        )
                                    )
                                else:
                                    route_utilization_map = (
                                        model.op_collections.route_utilization_map_op(
                                            pos
                                        )
                                    )
                                if params.plot_flag:
                                    # dump the routing utilization heat map
                                    path = "%s/%s" % (
                                        params.result_dir,
                                        params.design_name(),
                                    )
                                    figname = "%s/plot/route%d.png" % (
                                        path,
                                        num_area_adjust,
                                    )
                                    # NOTE(review): os.makedirs(..., exist_ok=True)
                                    # would avoid spawning a shell here
                                    os.system(
                                        "mkdir -p %s" % (os.path.dirname(figname))
                                    )
                                    plt.imsave(
                                        figname,
                                        route_utilization_map.data.cpu().numpy().T,
                                        origin="lower",
                                    )
                            if adjust_pin_area_flag:
                                pin_utilization_map = (
                                    model.op_collections.pin_utilization_map_op(pos)
                                )
                                if params.plot_flag:
                                    # dump the pin utilization heat map
                                    path = "%s/%s" % (
                                        params.result_dir,
                                        params.design_name(),
                                    )
                                    figname = "%s/plot/pin%d.png" % (
                                        path,
                                        num_area_adjust,
                                    )
                                    os.system(
                                        "mkdir -p %s" % (os.path.dirname(figname))
                                    )
                                    plt.imsave(
                                        figname,
                                        pin_utilization_map.data.cpu().numpy().T,
                                        origin="lower",
                                    )
                            # adjust node areas according to the utilization maps;
                            # the op returns updated flags indicating whether any
                            # adjustment actually took place
                            (
                                adjust_area_flag,
                                adjust_route_area_flag,
                                adjust_pin_area_flag,
                            ) = model.op_collections.adjust_node_area_op(
                                pos, route_utilization_map, pin_utilization_map
                            )
                            content += " -> (%d, %d, %d)" % (
                                adjust_area_flag,
                                adjust_route_area_flag,
                                adjust_pin_area_flag,
                            )
                            logging.info(content)
                            if adjust_area_flag:
                                num_area_adjust += 1
                                # restart Llambda
                                model.op_collections.density_op.reset()
                                model.op_collections.density_overflow_op.reset()
                                model.op_collections.pin_utilization_map_op.reset()
                                model.initialize_density_weight(params, placedb)
                                model.density_weight.mul_(0.1 / params.density_weight)
                                logging.info(
                                    "density_weight = %.6E"
                                    % (model.density_weight.data)
                                )
                                # load state to restart the optimizer
                                optimizer.load_state_dict(initial_state)
                                # must after loading the state
                                initialize_learning_rate(pos)
                                # increase iterations of the sub problem to slow down the search
                                model.Lsub_iteration = model.routability_Lsub_iteration

                                # reset best metric
                                best_metric[0] = None
                                best_pos[0] = None
                                break

                    # gradually reduce gamma to tradeoff smoothness and accuracy
                    if (
                        len(placedb.regions) > 0
                        and Llambda_metrics[-1][-1].goverflow is not None
                    ):
                        # fence regions present: drive gamma with the global overflow
                        model.op_collections.update_gamma_op(
                            Lgamma_step, Llambda_metrics[-1][-1].goverflow
                        )
                    elif (
                        len(placedb.regions) == 0
                        and Llambda_metrics[-1][-1].overflow is not None
                    ):
                        model.op_collections.update_gamma_op(
                            Lgamma_step, Llambda_metrics[-1][-1].overflow
                        )
                    else:
                        # no usable overflow for the gamma update; only refresh
                        # the preconditioner's overflow estimate
                        model.op_collections.precondition_op.set_overflow(
                            Llambda_metrics[-1][-1].overflow
                        )
                    if (
                        Lgamma_stop_criterion(Lgamma_step, Lgamma_metrics)
                        or stop_placement == 1
                    ):
                        break

                    # update learning rate
                    # only these optimizers use the ExponentialLR scheduler
                    if optimizer_name.lower() in [
                        "sgd",
                        "adam",
                        "sgd_momentum",
                        "sgd_nesterov",
                        "cg",
                    ]:
                        if "learning_rate_decay" in global_place_params:
                            scheduler.step()
                            lrs.append(optimizer.param_groups[0]["lr"])

                # in case of divergence, use the best metric
                # last_metric = all_metrics[-1][-1][-1]
                # if (
                #     last_metric.overflow[-1] > max(params.stop_overflow, best_metric[0].overflow[-1])
                #     and last_metric.hpwl > best_metric[0].hpwl
                # ):
                #     all_metrics.append([best_metric])

                logging.info(
                    "optimizer %s takes %.3f seconds"
                    % (optimizer_name, time.time() - tt)
                )

            # recover node size and pin offset for legalization, since node size is adjusted in global placement
            # self.pos[0] stores all x coordinates followed by all y coordinates
            # (y's start at offset placedb.num_nodes)
            if params.routability_opt_flag:
                with torch.no_grad():
                    # convert lower left to centers
                    self.pos[0][: placedb.num_movable_nodes].add_(
                        self.data_collections.node_size_x[: placedb.num_movable_nodes]
                        / 2
                    )
                    self.pos[0][
                        placedb.num_nodes : placedb.num_nodes
                        + placedb.num_movable_nodes
                    ].add_(
                        self.data_collections.node_size_y[: placedb.num_movable_nodes]
                        / 2
                    )
                    # restore the original (pre-adjustment) node sizes
                    self.data_collections.node_size_x.copy_(
                        self.data_collections.original_node_size_x
                    )
                    self.data_collections.node_size_y.copy_(
                        self.data_collections.original_node_size_y
                    )
                    # use fixed centers as the anchor
                    # convert centers back to lower-left with the restored sizes
                    self.pos[0][: placedb.num_movable_nodes].sub_(
                        self.data_collections.node_size_x[: placedb.num_movable_nodes]
                        / 2
                    )
                    self.pos[0][
                        placedb.num_nodes : placedb.num_nodes
                        + placedb.num_movable_nodes
                    ].sub_(
                        self.data_collections.node_size_y[: placedb.num_movable_nodes]
                        / 2
                    )
                    # restore the original pin offsets as well
                    self.data_collections.pin_offset_x.copy_(
                        self.data_collections.original_pin_offset_x
                    )
                    self.data_collections.pin_offset_y.copy_(
                        self.data_collections.original_pin_offset_y
                    )

        if params.plot_flag:
            # snapshot of the global placement result (iteration tag 9999)
            self.plot(params, placedb, 9999, self.pos[0].data.clone().cpu().numpy())

        # recover node sizes, pins shifts, and positions of macros
        # (the halo was presumably added to macro sizes earlier in the flow,
        # outside this view — undo it here before legalization)
        if params.macro_halo_x >= 0 and params.macro_halo_y >= 0:
            with torch.no_grad():
                # node sizes
                # remove the halo (applied on both sides) from movable macro sizes
                self.data_collections.node_size_x[placedb.movable_macro_idx] -= (
                    2 * params.macro_halo_x
                )
                self.data_collections.node_size_y[placedb.movable_macro_idx] -= (
                    2 * params.macro_halo_y
                )
                # self.data_collections.node_size_x[placedb.fixed_macro_idx] -= (
                #     2 * params.macro_halo_x
                # )
                # self.data_collections.node_size_y[placedb.fixed_macro_idx] -= (
                #     2 * params.macro_halo_y
                # )

                # pin offsets
                self.data_collections.pin_offset_x[
                    placedb.movable_macro_pins
                ] -= params.macro_halo_x
                self.data_collections.pin_offset_y[
                    placedb.movable_macro_pins
                ] -= params.macro_halo_y
                # self.data_collections.pin_offset_x[
                #     placedb.fixed_macro_pins
                # ] -= params.macro_halo_x
                # self.data_collections.pin_offset_y[
                #     placedb.fixed_macro_pins
                # ] -= params.macro_halo_y

                # macro locations
                # shift lower-left corners inward so the de-haloed macro outline
                # lands at the placed position
                # NOTE(review): x uses placedb.movable_slice / fixed_slice while
                # y uses explicit ranges — presumably equivalent; confirm the
                # slices cover [0:num_movable) and the terminal range
                self.pos[0][placedb.movable_slice][
                    placedb.movable_macro_mask
                ] += params.macro_halo_x
                self.pos[0][
                    placedb.num_nodes : placedb.num_nodes + placedb.num_movable_nodes
                ][placedb.movable_macro_mask] += params.macro_halo_y
                self.pos[0][placedb.fixed_slice][
                    placedb.fixed_macro_mask
                ] += params.macro_halo_x
                self.pos[0][
                    placedb.num_nodes
                    + placedb.num_movable_nodes : placedb.num_nodes
                    + placedb.num_movable_nodes
                    + placedb.num_terminals
                ][placedb.fixed_macro_mask] += params.macro_halo_y

        # rescale everything
        # bring the database to the scale implied by GCD(site_width, row_height)
        # so legalization/detailed placement work on consistent integer geometry
        cur_scale_factor = self.data_collections.fp_info.scale_factor
        gcd_site_scale_factor = 1 / math.gcd(
            placedb.pydb.site_width, placedb.pydb.row_height
        )
        if cur_scale_factor != gcd_site_scale_factor:
            # NOTE(review): logging.warn is a deprecated alias of logging.warning
            logging.warn(
                f"Rescaling by GCD(site_width, row_height) = {gcd_site_scale_factor} before legalization and detailed placement"
            )
            params.scale_factor = gcd_site_scale_factor
            rescale_factor = gcd_site_scale_factor / cur_scale_factor
            with torch.no_grad():
                # positions, node sizes, and region boxes are snapped to integers;
                # pin offsets are scaled but not rounded — presumably intentional
                # to keep sub-unit pin accuracy; confirm
                self.pos[0].mul_(rescale_factor).round_()
                self.data_collections.node_size_x.mul_(rescale_factor).round_()
                self.data_collections.node_size_y.mul_(rescale_factor).round_()
                self.data_collections.flat_region_boxes.mul_(rescale_factor).round_()
                self.data_collections.pin_offset_x.mul_(rescale_factor)
                self.data_collections.pin_offset_y.mul_(rescale_factor)
                # self.data_collections.node_areas.mul_(rescale_factor * rescale_factor)
                self.data_collections.fp_info.scale(rescale_factor)
                self.data_collections.fp_info.scale_factor = gcd_site_scale_factor
                # TODO: rescale fence regions

        # dump global placement solution for legalization
        if params.dump_global_place_solution_flag:
            self.dump(
                params,
                placedb,
                self.pos[0].cpu(),
                "%s.lg.pklz" % (params.design_name()),
            )

        # process metrics
        # all_metrics is a nested list-of-lists of EvalMetrics; flatten recursively
        flatten = lambda l: sum(map(flatten, l), []) if isinstance(l, list) else [l]
        metrics = flatten(all_metrics)
        objectives = [metric.objective.data.item() for metric in metrics]
        hpwls = [metric.hpwl.data.item() for metric in metrics]
        # NOTE(review): .item() assumes overflow is a scalar tensor, yet the
        # divergence check later indexes overflow[-1] — confirm consistency
        overflows = [metric.overflow.data.item() for metric in metrics]
        densities = [metric.max_density.data.item() for metric in metrics]
        # per-iteration curves returned to the caller
        processed_metrics = {
            "objective": objectives,
            "hpwl": hpwls,
            "overflow": overflows,
            "density": densities,
        }

        # plot losses
        if params.plot_flag:
            self.plot(
                params, placedb, iteration, self.pos[0].data.clone().cpu().numpy()
            )
            epochs = np.arange(len(objectives))
            num_subplots = 4
            # add an extra subplot for the learning-rate curve when at least one
            # lr value was recorded per metric entry
            if len(lrs) >= len(epochs):
                lrs = [float(lr) for lr in lrs]
                num_subplots += 1
            if params.macro_overlap_flag:
                macro_overlaps = [
                    metric.macro_overlap.data.item() for metric in metrics
                ]
                num_subplots += 1
            fig, axis = plt.subplots(1, num_subplots, figsize=(5 * num_subplots, 5))
            # objective and HPWL on log10 scale; overflow/density linear
            axis[0].plot(epochs, np.log10(objectives), color="red")
            axis[0].set_title("Obj")
            axis[1].plot(epochs, np.log10(hpwls), color="blue")
            axis[1].set_title("HPWL")
            axis[2].plot(epochs, overflows, color="green")
            axis[2].set_title("Ovflw")
            axis[3].plot(epochs, densities, color="orange")
            axis[3].set_title("MaxDens")
            plotted = 4
            if len(lrs) >= len(epochs):
                axis[plotted].plot(
                    epochs, np.log10(lrs[0 : len(epochs)]), color="black"
                )
                axis[plotted].set_title("Lr")
                plotted += 1
            if params.macro_overlap_flag:
                axis[plotted].plot(epochs, macro_overlaps, color="purple")
                axis[plotted].set_title("MacroOvlp")
            fig.tight_layout(pad=2.0)
            path = "%s/%s" % (params.result_dir, params.design_name())
            figname = os.path.join(path, f"{params.design_name()}.loss.png")
            plt.savefig(figname)

        # drill three levels into the nested metrics (matching the
        # Lgamma/Llambda/Lsub nesting) to reach the most recent EvalMetrics;
        # NOTE(review): the deepcopy looks defensive — indexing alone would not
        # mutate all_metrics; confirm whether it is needed
        last_metric = copy.deepcopy(all_metrics)
        for idx in [-1, -1, -1]:
            try:
                last_metric = last_metric[idx]
            except IndexError:
                # some nesting level was empty; flag that no final metric exists
                last_metric = False
                break

        if not last_metric:
            # no metric was produced; evaluate HPWL once so a final metric is
            # always logged and appended
            cur_metric = EvalMetrics.EvalMetrics(iteration)
            all_metrics.append(cur_metric)
            cur_metric.evaluate(
                placedb, {"hpwl": self.op_collections.hpwl_op}, self.pos[0]
            )
            logging.info(cur_metric)

        # in case of significant divergence, no need to run legalizer
        if last_metric and (
            last_metric.overflow[-1] > params.stop_overflow
            or torch.isinf(last_metric.objective)
            or torch.isnan(last_metric.objective)
        ):
            # NOTE(review): the message says "hpwl is infinity or nan" but the
            # inf/nan check is on the objective — confirm which is intended;
            # logging.warn is also a deprecated alias of logging.warning
            logging.warn(
                "overflow is significant %.3f or hpwl is infinity or nan, skip legalization and detail placement steps"
                % (last_metric.overflow[-1])
            )
            self.plot(params, placedb, 9999, self.pos[0].data.clone().cpu().numpy())
            # signal divergence to the caller with infinite wirelengths
            return float("inf"), float("inf"), processed_metrics

        # legalization
        if params.legalize_flag:
            tt = time.time()
            # legalize_op returns the legalized positions; write them back in place
            self.pos[0].data.copy_(self.op_collections.legalize_op(self.pos[0]))
            logging.info("legalization takes %.3f seconds" % (time.time() - tt))
            cur_metric = EvalMetrics.EvalMetrics(iteration)
            all_metrics.append(cur_metric)
            cur_metric.evaluate(
                placedb, {"hpwl": self.op_collections.hpwl_op}, self.pos[0]
            )
            logging.info(cur_metric)
            iteration += 1

        # plot placement
        if params.plot_flag:
            self.plot(
                params, placedb, iteration, self.pos[0].data.clone().cpu().numpy()
            )

        # dump legalization solution for detailed placement
        if params.dump_legalize_solution_flag:
            self.dump(
                params,
                placedb,
                self.pos[0].cpu(),
                "%s.dp.pklz" % (params.design_name()),
            )

        # detailed placement
        if params.detailed_place_flag:
            tt = time.time()
            self.pos[0].data.copy_(self.op_collections.detailed_place_op(self.pos[0]))
            # macro refinement proposes per-macro orientations (applied later,
            # under the same detailed_place_flag guard)
            macro_orients = self.op_collections.macro_refinement_op(self.pos[0])
            logging.info("detailed placement takes %.3f seconds" % (time.time() - tt))
            cur_metric = EvalMetrics.EvalMetrics(iteration)
            all_metrics.append(cur_metric)
            cur_metric.evaluate(
                placedb, {"hpwl": self.op_collections.hpwl_op}, self.pos[0]
            )
            logging.info(cur_metric)
            iteration += 1

        # save results
        cur_pos = self.pos[0].data.clone().cpu().numpy()
        # apply position solution/ will also update movable node orientations based on row
        # x coordinates come from the first half of pos, y's from the second half
        placedb.apply(
            params,
            cur_pos[0 : placedb.num_movable_nodes],
            cur_pos[placedb.num_nodes : placedb.num_nodes + placedb.num_movable_nodes],
        )

        # apply macro orientations solution
        if params.detailed_place_flag:
            # map orientation names to the integer codes passed to place_io's
            # OrientEnum.OrientType
            orients_map = {
                "N": 0,
                "S": 1,
                "W": 2,
                "E": 3,
                "FN": 4,
                "FS": 5,
                "FW": 6,
                "FE": 7,
                "UNKNOWN": 8,
            }
            # macro_orients was produced by macro_refinement_op under the same
            # detailed_place_flag guard, so it is always defined here
            for macro, orient in macro_orients:
                placedb.rawdb.setNodeOrient(
                    int(macro), place_io_cpp.OrientEnum.OrientType(orients_map[orient])
                )

        # update pin offsets of std cells
        # assume rows are FS = 0, N = 1, FS, ...
        # cur_orient = torch.from_numpy(
        #     np.where(placedb.node_orient == b"N", 1, 0)[: placedb.num_movable_nodes]
        # ).to(self.pos[0].device)
        # new_orient = (
        #     torch.div(
        #         self.pos[0][
        #             placedb.num_nodes : placedb.num_nodes + placedb.num_movable_nodes
        #         ],
        #         self.data_collections.fp_info.row_height,
        #         rounding_mode="floor",
        #     )
        #     % 2
        # )
        # flips = (
        #     (cur_orient != new_orient) & ~self.data_collections.movable_macro_mask
        # ).nonzero().view(-1)
        # self.data_collections.pin_offset_y[flips] = (
        #     self.data_collections.fp_info.row_height
        #     - self.data_collections.pin_offset_y[flips]
        # )

        # reset net weights
        self.data_collections.net_weights.fill_(1.0)

        # plot placement
        # if params.plot_flag:
        # NOTE(review): the plot_flag guard above is commented out, so this
        # final plot always runs — confirm that is intended
        self.plot(params, placedb, iteration, cur_pos)

        # run RSMT
        # Steiner-tree wirelength of the final placement (FLUTE-based, per the log)
        with torch.no_grad():
            tt = time.time()
            rsmt_wl = self.op_collections.rsmt_wl_op(self.pos[0])
            logging.info("rsmt computation takes %.3f seconds" % (time.time() - tt))
            logging.info("flute rsmt %.6E" % rsmt_wl)

        # get HPWL
        with torch.no_grad():
            hpwl = self.op_collections.hpwl_op(self.pos[0])
            logging.info("unweighted hpwl %.6E" % hpwl)

        # save nets degree, RSMT, HPWL
        # with torch.no_grad():
        #     degrees = torch.from_numpy(np.ediff1d(placedb.flat_net2pin_start_map))
        #     mask = torch.logical_and(2 <= degrees, degrees < params.ignore_net_degree)
        #     degrees = degrees[mask].long()
        #     steiners = self.op_collections.rsmt_wl_op(self.pos[0], False)[mask]
        #     wirelengths = (
        #         self.op_collections.hpwl_op(self.pos[0], False)
        #         .cpu()
        #         .detach()[mask]
        #     )
        #     weights = steiners / wirelengths
        #     # get new RISA weights
        #     degrees, indices = torch.sort(degrees)
        #     weights = weights[indices]
        #     c = torch.stack((degrees, weights))
        #     idxs, vals = torch.unique(c[0, :], return_counts=True)
        #     vs = torch.split_with_sizes(c[1, :], tuple(vals))
        #     weights_dict = {int(k.item()): float(v.mean()) for k, v in zip(idxs, vs)}
        #     path = "%s/%s" % (params.result_dir, params.design_name())
        #     with open("%s/risa_weights.pkl" % path, "wb") as f:
        #         pickle.dump(weights_dict, f)

        # returns (RSMT wirelength, unweighted HPWL, per-iteration metric curves)
        return float(rsmt_wl), float(hpwl), processed_metrics
